{
    "cells": [
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "<i>Copyright (c) Recommenders contributors.</i>\n",
                "\n",
                "<i>Licensed under the MIT License.</i>"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "inputHidden": false,
                "outputHidden": false
            },
            "source": [
                "# Movie recommender with multinomial RBM (TensorFlow, GPU)\n",
                "\n",
                "A Restricted Boltzmann Machine (RBM) is a generative neural network model typically used to perform unsupervised learning. The main task of an RBM is to learn the joint probability distribution $P(v,h)$, where $v$ are the visible units and $h$ the hidden ones. The hidden units represent latent variables, while the visible units are clamped to the input data. Once the joint distribution has been learned, new examples can be generated by sampling from it.\n",
                "\n",
                "In this notebook, we provide an example of how to use the RBM to perform user/item recommendations. In particular, we use the [MovieLens dataset](https://movielens.org) as a case study; it contains users' ratings of movies on a scale from 1 to 5.\n",
                "\n",
                "This notebook provides a quick start, showing the basic steps needed to use and evaluate the algorithm. A detailed discussion of the RBM model, together with a deeper analysis of the recommendation task, is provided in the [RBM Deep Dive section](../02_model/rbm_deep_dive.ipynb). The RBM implementation presented here is based on the article by Ruslan Salakhutdinov, Andriy Mnih and Geoffrey Hinton, [Restricted Boltzmann Machines for Collaborative Filtering](https://www.cs.toronto.edu/~rsalakhu/papers/rbmcf.pdf), with the exception that here we use multinomial units instead of the one-hot encoded units used in the paper.\n",
                "\n",
                "### Advantages of RBM: \n",
                "\n",
                "The model generates ratings for a user/movie pair using a collaborative filtering approach. While matrix factorization methods learn how to reproduce an instance of the user/item affinity matrix, the RBM learns the underlying probability distribution. This has several advantages:\n",
                "\n",
                "- Generalizability: the model generalizes well to new examples.\n",
                "- Stability in time: if the recommendation task is time-stationary, the model does not need to be retrained often to accommodate new ratings/users.\n",
                "- Speed: the TensorFlow implementation presented here allows fast training on GPU."
            ]
        },
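        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "For reference, here is one common way to write the joint distribution $P(v,h)$, shown for the simplified case of binary visible and hidden units (the multinomial model used in this notebook instead lets each visible unit take one of the possible rating values). The notation is introduced here for illustration only: $E(v,h)$ is an energy function, $a_i$ and $b_j$ are the visible and hidden biases, and $w_{ij}$ are the weights.\n",
                "\n",
                "$$E(v,h) = -\\sum_i a_i v_i - \\sum_j b_j h_j - \\sum_{i,j} v_i w_{ij} h_j, \\qquad P(v,h) = \\frac{e^{-E(v,h)}}{Z}, \\qquad Z = \\sum_{v,h} e^{-E(v,h)}.$$\n",
                "\n",
                "Because there are no connections within a layer, the conditionals factorize over the units, which is what makes sampling practical ($\\sigma$ denotes the logistic function):\n",
                "\n",
                "$$P(h_j = 1 \\mid v) = \\sigma\\Big(b_j + \\sum_i v_i w_{ij}\\Big), \\qquad P(v_i = 1 \\mid h) = \\sigma\\Big(a_i + \\sum_j w_{ij} h_j\\Big).$$"
            ]
        },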
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## 0 Global Settings and Import"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 1,
            "metadata": {},
            "outputs": [
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "System version: 3.7.12 | packaged by conda-forge | (default, Oct 26 2021, 06:08:21) \n",
                        "[GCC 9.4.0]\n",
                        "Pandas version: 1.3.5\n",
                        "Tensorflow version: 2.7.0\n"
                    ]
                }
            ],
            "source": [
                "import sys\n",
                "import numpy as np\n",
                "import pandas as pd\n",
                "import tensorflow as tf\n",
                "tf.get_logger().setLevel('ERROR') # only show error messages\n",
                "\n",
                "from recommenders.models.rbm.rbm import RBM\n",
                "from recommenders.datasets.python_splitters import numpy_stratified_split\n",
                "from recommenders.datasets.sparse import AffinityMatrix\n",
                "from recommenders.datasets import movielens\n",
                "from recommenders.evaluation.python_evaluation import map_at_k, ndcg_at_k, precision_at_k, recall_at_k\n",
                "from recommenders.utils.timer import Timer\n",
                "from recommenders.utils.plot import line_graph\n",
                "from recommenders.utils.notebook_utils import store_metadata\n",
                "\n",
                "#For interactive mode only\n",
                "%load_ext autoreload\n",
                "%autoreload 2\n",
                "%matplotlib inline\n",
                "\n",
                "print(f\"System version: {sys.version}\")\n",
                "print(f\"Pandas version: {pd.__version__}\")\n",
                "print(f\"Tensorflow version: {tf.__version__}\")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## 1 Load Data\n",
                "\n",
                "Here we select the size of the MovieLens dataset. In this example we use the 100k ratings dataset, consisting of ratings given by 943 users to 1682 movies. The data are imported into a pandas DataFrame containing the user ID, the item ID, the rating and a timestamp denoting when a particular user rated a particular item."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 2,
            "metadata": {
                "tags": [
                    "parameters"
                ]
            },
            "outputs": [],
            "source": [
                "# Select MovieLens data size: 100k, 1m, 10m, or 20m\n",
                "MOVIELENS_DATA_SIZE = '100k'"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 3,
            "metadata": {},
            "outputs": [
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "100%|██████████| 4.81k/4.81k [00:00<00:00, 30.9kKB/s]\n"
                    ]
                },
                {
                    "data": {
                        "text/html": [
                            "<div>\n",
                            "<style scoped>\n",
                            "    .dataframe tbody tr th:only-of-type {\n",
                            "        vertical-align: middle;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe tbody tr th {\n",
                            "        vertical-align: top;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe thead th {\n",
                            "        text-align: right;\n",
                            "    }\n",
                            "</style>\n",
                            "<table border=\"1\" class=\"dataframe\">\n",
                            "  <thead>\n",
                            "    <tr style=\"text-align: right;\">\n",
                            "      <th></th>\n",
                            "      <th>userID</th>\n",
                            "      <th>movieID</th>\n",
                            "      <th>rating</th>\n",
                            "      <th>timestamp</th>\n",
                            "    </tr>\n",
                            "  </thead>\n",
                            "  <tbody>\n",
                            "    <tr>\n",
                            "      <th>0</th>\n",
                            "      <td>196</td>\n",
                            "      <td>242</td>\n",
                            "      <td>3.0</td>\n",
                            "      <td>881250949</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>1</th>\n",
                            "      <td>186</td>\n",
                            "      <td>302</td>\n",
                            "      <td>3.0</td>\n",
                            "      <td>891717742</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>2</th>\n",
                            "      <td>22</td>\n",
                            "      <td>377</td>\n",
                            "      <td>1.0</td>\n",
                            "      <td>878887116</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>3</th>\n",
                            "      <td>244</td>\n",
                            "      <td>51</td>\n",
                            "      <td>2.0</td>\n",
                            "      <td>880606923</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>4</th>\n",
                            "      <td>166</td>\n",
                            "      <td>346</td>\n",
                            "      <td>1.0</td>\n",
                            "      <td>886397596</td>\n",
                            "    </tr>\n",
                            "  </tbody>\n",
                            "</table>\n",
                            "</div>"
                        ],
                        "text/plain": [
                            "   userID  movieID  rating  timestamp\n",
                            "0     196      242     3.0  881250949\n",
                            "1     186      302     3.0  891717742\n",
                            "2      22      377     1.0  878887116\n",
                            "3     244       51     2.0  880606923\n",
                            "4     166      346     1.0  886397596"
                        ]
                    },
                    "execution_count": 3,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "data = movielens.load_pandas_df(\n",
                "    size=MOVIELENS_DATA_SIZE,\n",
                "    header=['userID','movieID','rating','timestamp']\n",
                ")\n",
                "\n",
                "data.head()"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### 1.2 Split the data using the stratified splitter  \n",
                "\n",
                "As a second step, we generate the user/item affinity matrix and then split the data into train and test sets. If you are familiar with training supervised learning models, you will notice the first difference here. In supervised learning, we hold out a certain proportion of the examples in the dataset (e.g. images); here these would correspond to users (or items), and we would end up with two matrices (train and test) with different numbers of rows. Here, instead, we need to maintain the same matrix size for the train and test sets, but the two contain different ratings; see the [deep dive notebook](../02_model/rbm_deep_dive.ipynb) for more details. The affinity matrix is generated as follows:"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 4,
            "metadata": {
                "inputHidden": false,
                "outputHidden": false,
                "tags": [
                    "sparse_matrix"
                ]
            },
            "outputs": [],
            "source": [
                "# use standard column names across the analysis\n",
                "header = {\n",
                "    \"col_user\": \"userID\",\n",
                "    \"col_item\": \"movieID\",\n",
                "    \"col_rating\": \"rating\",\n",
                "}\n",
                "\n",
                "# instantiate the affinity matrix generator\n",
                "am = AffinityMatrix(df=data, **header)\n",
                "\n",
                "# obtain the user/item affinity matrix\n",
                "X, _, _ = am.gen_affinity_matrix()"
            ]
        },
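        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "To make the layout of the affinity matrix concrete, the next cell is a minimal sketch that builds a user/item matrix from a long-format ratings table with pandas `pivot` and computes its sparseness. The toy ratings are made up for illustration, and this is not necessarily how `AffinityMatrix` works internally: here rows and columns simply follow the sorted user and movie IDs, and unrated pairs are filled with 0. The last line applies the same ratio to the size of MovieLens 100k."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import numpy as np\n",
                "import pandas as pd\n",
                "\n",
                "# Hypothetical toy ratings in the same long format as the MovieLens dataframe\n",
                "ratings = pd.DataFrame({\n",
                "    'userID': [196, 196, 186, 22],\n",
                "    'movieID': [242, 302, 302, 377],\n",
                "    'rating': [3.0, 4.0, 3.0, 1.0],\n",
                "})\n",
                "\n",
                "# One row per user, one column per movie; unrated pairs become 0\n",
                "X_toy = (\n",
                "    ratings.pivot(index='userID', columns='movieID', values='rating')\n",
                "    .fillna(0)\n",
                "    .to_numpy()\n",
                ")\n",
                "\n",
                "# Sparseness: fraction of matrix elements that are unrated (equal to 0)\n",
                "sparseness = (X_toy == 0).sum() / X_toy.size\n",
                "print(X_toy.shape, round(sparseness, 3))\n",
                "\n",
                "# The same ratio for MovieLens 100k: 100,000 ratings in a 943 x 1682 matrix\n",
                "print(round(1 - 100_000 / (943 * 1682), 3))"
            ]
        },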
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "`gen_affinity_matrix()` also reports the sparseness of the dataset and the size of the user/item affinity matrix. The sparseness is the ratio between the number of unrated elements and the total number of matrix elements. This is what makes a recommendation task hard: we try to predict more than 93% of the matrix entries (the missing ratings) from fewer than 7% of observed ones!\n",
                "\n",
                "We split the matrix using the default ratio of 0.75, i.e. 75% of the ratings will constitute the train set."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 5,
            "metadata": {
                "tags": [
                    "split"
                ]
            },
            "outputs": [],
            "source": [
                "Xtr, Xtst = numpy_stratified_split(X)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "The splitter returns:\n",
                "\n",
                "- Xtr: a matrix containing the train set ratings\n",
                "- Xtst: a matrix containing the test set ratings\n",
                "\n",
                "Note that the train and test matrices contain different entries but have exactly the same dimensions, as can be verified explicitly:"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 6,
            "metadata": {},
            "outputs": [
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "train matrix size (943, 1682)\n",
                        "test matrix size (943, 1682)\n"
                    ]
                }
            ],
            "source": [
                "print('train matrix size', Xtr.shape)\n",
                "print('test matrix size', Xtst.shape)"
            ]
        },
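        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "The idea behind a stratified split can be sketched in a few lines of NumPy. The helper `toy_stratified_split` below is hypothetical: it assumes a random per-user split in which each user keeps a fraction `ratio` of their ratings in the train set and the remaining ones go to the test set. It only illustrates the principle and is not the implementation of `numpy_stratified_split`, which may differ in its details. The assertions check explicitly that train and test have the same shape, never share a rating, and together recover the full matrix."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import numpy as np\n",
                "\n",
                "def toy_stratified_split(X, ratio=0.75, seed=42):\n",
                "    # Hypothetical helper: for each user (row), keep a fraction `ratio`\n",
                "    # of the rated items in train and move the remaining ones to test\n",
                "    rng = np.random.default_rng(seed)\n",
                "    Xtr = np.zeros_like(X)\n",
                "    Xtst = np.zeros_like(X)\n",
                "    for u in range(X.shape[0]):\n",
                "        rated = np.flatnonzero(X[u])  # indices of the items rated by user u\n",
                "        rng.shuffle(rated)\n",
                "        n_train = int(round(ratio * len(rated)))\n",
                "        Xtr[u, rated[:n_train]] = X[u, rated[:n_train]]\n",
                "        Xtst[u, rated[n_train:]] = X[u, rated[n_train:]]\n",
                "    return Xtr, Xtst\n",
                "\n",
                "# Toy user/item matrix: 3 users x 8 movies, 0 = unrated\n",
                "X_toy = np.array([\n",
                "    [5, 3, 0, 1, 4, 0, 2, 0],\n",
                "    [0, 4, 4, 0, 0, 5, 0, 1],\n",
                "    [1, 0, 0, 2, 3, 0, 4, 5],\n",
                "])\n",
                "Xtr_toy, Xtst_toy = toy_stratified_split(X_toy)\n",
                "\n",
                "# Same shape, no rating in both sets, and together they recover the full matrix\n",
                "assert Xtr_toy.shape == Xtst_toy.shape == X_toy.shape\n",
                "assert not np.any((Xtr_toy > 0) & (Xtst_toy > 0))\n",
                "assert np.array_equal(Xtr_toy + Xtst_toy, X_toy)\n",
                "print((Xtr_toy > 0).sum(axis=1), (Xtst_toy > 0).sum(axis=1))"
            ]
        },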
        {
            "cell_type": "markdown",
            "metadata": {
                "tags": [
                    "model",
                    "train"
                ]
            },
            "source": [
                "## 2 Train the RBM model\n",
                "\n",
                "The model has been implemented as a TensorFlow (TF) class. TF does not support probabilistic models natively, so the implementation of the algorithm has a different structure than the one you may be used to seeing in popular supervised models. The class has been implemented in such a way that the TF session is hidden inside the `fit()` method and no explicit call is needed. The algorithm operates in three steps:\n",
                "\n",
                "- Model initialization: This is where we tell TF how to build the computational graph. The main parameters to specify are the number of hidden units, the number of training epochs and the minibatch size. Other parameters can be optionally tweaked for experimentation and to achieve better performance, as explained in the [RBM Deep Dive section](../02_model/rbm_deep_dive.ipynb).\n",
                "\n",
                "- Model fit: This is where we train the model on the data. The method takes the training set matrix as input, so the model is trained **only** on the training set. With `with_metrics=True`, the training RMSE is recorded at each epoch, which helps to get an idea of how to tune the hyperparameters.\n",
                "\n",
                "- Model prediction: This is where we generate ratings for the unseen items. Once the model has been trained and we are satisfied with its overall accuracy, we sample new ratings from the learned distribution. In particular, we extract the top_k (e.g. 10) most relevant recommendations according to some predefined score. The prediction is then returned in a dataframe format ready to be analysed and deployed.  "
            ]
        },
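        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "RBMs are commonly trained with contrastive divergence (CD), which is also the approach of the paper cited above. As a rough, self-contained illustration of what happens during training, the next cell sketches a single CD-1 update in NumPy for a simplified RBM with binary visible and hidden units on random toy data. The function `cd1_step` and all its parameters are hypothetical; this is not the TensorFlow implementation used by `RBM`, which works with multinomial visible units and includes further refinements discussed in the deep dive."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import numpy as np\n",
                "\n",
                "def sigmoid(x):\n",
                "    return 1.0 / (1.0 + np.exp(-x))\n",
                "\n",
                "def cd1_step(v0, W, a, b, lr, rng):\n",
                "    # One CD-1 update for a binary RBM; W, a, b are updated in place.\n",
                "    # v0: (batch, n_visible) data, W: (n_visible, n_hidden),\n",
                "    # a: visible biases, b: hidden biases\n",
                "    ph0 = sigmoid(v0 @ W + b)                         # P(h = 1 | v0)\n",
                "    h0 = (rng.random(ph0.shape) < ph0).astype(float)  # sample hidden states\n",
                "    pv1 = sigmoid(h0 @ W.T + a)                       # P(v = 1 | h0): reconstruction\n",
                "    v1 = (rng.random(pv1.shape) < pv1).astype(float)  # sample visible states\n",
                "    ph1 = sigmoid(v1 @ W + b)                         # P(h = 1 | v1)\n",
                "    n = v0.shape[0]\n",
                "    # Positive (data-driven) phase minus negative (model-driven) phase\n",
                "    W += lr * (v0.T @ ph0 - v1.T @ ph1) / n\n",
                "    a += lr * (v0 - v1).mean(axis=0)\n",
                "    b += lr * (ph0 - ph1).mean(axis=0)\n",
                "    # Reconstruction error of the minibatch, computed before this update\n",
                "    return np.sqrt(np.mean((v0 - pv1) ** 2))\n",
                "\n",
                "rng = np.random.default_rng(0)\n",
                "v_data = (rng.random((60, 20)) < 0.3).astype(float)  # toy binary minibatch\n",
                "W = 0.01 * rng.standard_normal((20, 8))\n",
                "a = np.zeros(20)\n",
                "b = np.zeros(8)\n",
                "errors = [cd1_step(v_data, W, a, b, lr=0.1, rng=rng) for _ in range(200)]\n",
                "print(f'reconstruction RMSE: {errors[0]:.3f} -> {errors[-1]:.3f}')"
            ]
        },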
        {
            "cell_type": "code",
            "execution_count": 7,
            "metadata": {
                "inputHidden": false,
                "outputHidden": false,
                "tags": [
                    "initialization"
                ]
            },
            "outputs": [],
            "source": [
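                "# First, we initialize the model class\n",
                "model = RBM(\n",
                "    # rating values observed in the data, excluding 0 (unrated)\n",
                "    possible_ratings=np.setdiff1d(np.unique(Xtr), np.array([0])),\n",
                "    visible_units=Xtr.shape[1],  # one visible unit per movie\n",
                "    hidden_units=600,\n",
                "    training_epoch=30,\n",
                "    minibatch_size=60,\n",
                "    keep_prob=0.9,  # keep probability for dropout regularization\n",
                "    with_metrics=True  # record the training RMSE at each epoch\n",
                ")"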
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Note that the first call to the fit method may take longer to return. This is because TF needs to initialize the GPU session; subsequent training runs do not incur this overhead."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 8,
            "metadata": {
                "inputHidden": false,
                "outputHidden": false,
                "tags": [
                    "training"
                ]
            },
            "outputs": [
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "Took 2.49 seconds for training.\n"
                    ]
                },
                {
                    "data": {
                        "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVEAAAE9CAYAAACyQFFjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAArsUlEQVR4nO3deXxU9b3/8dcnCwkQSAIkQNh3ZA0aRZEqbrggoFa94vKz2mptq7a29Vbba2lre9tb297ettZqW0RbxbaCFbV1LYIsimHf9y2sYd8JhM/vjxlshGyTzORMMu/n4zEPknPOzHwykDfne77f8/2auyMiIjWTFHQBIiL1mUJURKQWFKIiIrWgEBURqQWFqIhILShERURqISXoAqKpVatW3rlz56DLEJEGZs6cOTvdPae8fQ0qRDt37kxhYWHQZYhIA2NmGyrap+a8iEgtxDREzWycme0ws8UV7M80s9fMbIGZLTGzu8rsu9PMVoUfd8ayThGRmor1meh44KpK9n8FWOruA4FhwM/NrJGZtQDGAoOB84CxZpYd41pFRCIW0xB192nA7soOAZqZmQEZ4WNPAFcC77j7bnffA7xD5WEsIhKIoDuWfgNMBrYAzYD/cPeTZtYO2FTmuCKgXQD1iYhUKuiOpSuB+UAekA/8xsyaR/ICZnavmRWaWWFxcXH0KxQRqUTQIXoXMMlDVgPrgN7AZqBDmePah7edwd2fcfcCdy/IySl3GJeISMwEHaIbgcsAzKw10AtYC7wFDDez7HCH0vDwNhGRuBLTa6JmNoFQr3srMysi1OOeCuDuvwMeB8ab2SLAgG+5+87wcx8HPg6/1A/cvbIOKhGRQMQ0RN19TBX7txA6yyxv3zhgXCzqOuXYiVJ2HyqhbWbjWL6NiDRgQTfnAzVleTHfmrgo6DJEpB5L6BC9oFtL5qzfzdHjpUGXIiL1VEKHaGbjVHq0bsbcDXuCLkVE6qmEDlGAz/RoxQerdwZdhojUUwkfokO7t2KGQlREaijhQ3RQx2zWFR9iz6GSoEsRkXoo4UO0UUoSBZ2zmblmV9CliEg9lPAhCjC0Rw7TV+u+exGJnEKUcOfSqp24e9CliEg9oxAFeuRmcLz0JBt3Hw66FBGpZxSigJlxYffQ2aiISCQUomFDu7diukJURCKkEA0b2r0Vs9buovSkrouKSPUpRMNym6fTpnk6C4v2Bl2KiNQjCtEyhvZQk15EIqMQLWNoj1ZM1y2gIhIBhWgZg7u0YPHmfRw6diLoUkSknlCIltGkUQr92mUye51WIhGR6lGInubU3UsiItWhED2N7qMXkUgoRE/Tv10m2/cfY/v+o0GXIiL1gEL0NMlJxgVdW2qiZhGpFoVoOTReVESqSyFajs+Ex4tqajwRqYpCtBydWjYlLTWJldsPBl2KiMQ5hWgFhnbP4YNV6qUXkcopRCugVUBFpDoUohUY0q0lH6/fQ8mJk0GXIiJxTCFageymjeia05S5G/cEXYqIxDGFaCU0272IVEUhWomhPVrxga6LikglFKKVOKdTNqu3H2Dv4ZKgSxGROKUQrURaSjKf6ZHDW0u2BV2KiMQphWgVRuXnMXnBlqDLEJE4FdMQNbNxZrbDzBZXsP9hM5sffiw2s1IzaxHe95CZLQlvn2Bm6bGstSKX9s5lUdE+dhzQrE4icqZYn4mOB66qaKe7P+Hu+e6eDzwKTHX33WbWDngQKHD3fkAycEuMay1Xemoyl5/VmjcWbg3i7UUkzsU0RN19GlDdtTbGABPKfJ8CNDazFKAJEFibeqSa9CJSgbi4JmpmTQidsU4EcPfNwM+AjcBWYJ+7vx1UfUO7t2LDrsNs2n04qBJEJE7FRYgCI4EZ7r4bwMyygdFAFyAPaGpmt5f3RDO718wKzaywuDg2E4akJidxdb82OhsVkTPES4jewqeb8pcD69y92N2PA5OAIeU90d2fcfcCdy
/IycmJWYGjBubxmkJURE4TeIiaWSZwMfBqmc0bgfPNrImZGXAZsCyI+k45t3ML9h05zsrtB4IsQ0TiTKyHOE0AZgG9zKzIzD5vZveZ2X1lDrseeNvdD53a4O4fAS8Dc4FF4TqfiWWtVUlKMq4d0JbJ83U2KiL/Zg1pCYyCggIvLCyM2esvKtrHV16cy9SHhxE6QRaRRGBmc9y9oLx9gTfn65N+7ZqTnGQsKNoXdCkiEicUohEwM0YOzFOTXkQ+oRCN0KiBeby+cAulJxvOZRARqTmFaIS652bQKiONj9btCroUEYkDCtEaGJWvMaMiEqIQrYGRA/N4c/E2LWInIgrRmmiX1ZhuORlal15EFKI1pcmaRQQUojV2Tf+2/Gv5Do6UlAZdiogESCFaQ60y0sjvkMV7y7cHXYqIBEghWgujNPBeJOEpRGvhyn5tmLVmF/uOHA+6FBEJiEK0FpqnpzKke0stqSySwBSitTS8TxumrtRQJ5FEpRCtpYEdslhYtDfoMkQkIArRWuraqil7Dx1n96GSoEsRkQAoRGspKcno1y6TBTobFUlICtEoGNghi4WbNFGzSCJSiEbBwPaZui4qkqAUolEwoEMWC4r20ZDWqxKR6lGIRkFeZjrgbN13NOhSRKSOKUSjwMwY2D6LBZv2Bl2KiNQxhWiUDGifpVVARRKQQjRKBnRQ55JIIlKIRsnA9lks2ryPk1oFVCShKESjpEXTRmQ2TmXtzkNBlyIidUghGkW6j14k8ShEoyg06F6dSyKJRCEaRaEe+r1BlyEidUghGkX92mWyfOsBjpdqPXqRRKEQjaKMtBQ6tGjMim0Hgi5FROqIQjTK1KQXSSwK0Sgb2D5T0+KJJBCFaJTpTFQkscQ0RM1snJntMLPFFex/2Mzmhx+LzazUzFqE92WZ2ctmttzMlpnZBbGsNVp6t23Ghl2HOVxyIuhSRKQOxPpMdDxwVUU73f0Jd89393zgUWCqu+8O7/4/4E137w0MBJbFuNaoSEtJpmfrDJZs2R90KSJSB2Iaou4+Ddhd5YEhY4AJAGaWCVwE/DH8OiXuvjcWNcbCAE2LJ5Iw4uKaqJk1IXTGOjG8qQtQDDxrZvPM7A9m1jSwAiM0QHcuiSSMuAhRYCQwo0xTPgU4G3jK3QcBh4BHynuimd1rZoVmVlhcXFw31VYhv4M6l0QSRbyE6C2Em/JhRUCRu38U/v5lQqF6Bnd/xt0L3L0gJycnxmVWT9ecDHYdLGHvYa1FL9LQBR6i4eufFwOvntrm7tuATWbWK7zpMmBpAOXVSHKS0TevuZr0IgkgJZYvbmYTgGFAKzMrAsYCqQDu/rvwYdcDb7v76RNxPgC8YGaNgLXAXbGsNdpOTYt3Uc/4ODsWkdiIaYi6+5hqHDOe0FCo07fPBwqiXlQdGdA+k1fnbwm6DBGJscCb8w2VVv8USQwK0Rhpn92YEyedbVqLXqRBU4jGiJkxoH2mhjqJNHAK0Rga0F5rLok0dArRGMrvkMkCTYsn0qApRGPo1Jmou9aiF2moFKIx1CojjWbpqazfdTjoUkQkRhSiMRaajGRv0GWISIwoRGNsYIcs5mu8qEiDpRCNMU2LJ9KwKURjrH+7TJZt3c/M1TuDLkVEYkAhGmPN0lP5+U0DeWTSIv7fuNks3qyzUpGGRCFaB67u35Z3v34xl/XO5a7xH/PghHlsVI+9SIOgEK0jjVKSuHNIZ97/5jC65WQw+snpjH11MTsPHgu6NBGpBavuQHAzSwM+C3SmzBR67v6DmFRWAwUFBV5YWBh0GdWy6+AxfjNlNa/M28ydF3Tmixd3pUmjmM5MKCI1ZGZz3L3cqTkjORN9FRgNnCC05tGph9RAy4w0xo7sy2v3D2Vh0V5++uaKoEsSkRqI5NSnvbtXuIa81EyHFk347si+3Pz0LB67tg/JSRZ0SSISgUjORGeaWf+YVZLAurRqSpvm6Xy4dlfQpYhIhCIJ0aHAHDNbYWYLzWyRmS2MVW
GJZtTAPF5boOVEROqbSEL0aqAHMJzQOvHXhv+UKBgxoC1vLtlGyYmTQZciIhGoMkTNrHn4ywMVPCQK8rIa0zO3GR+sKg66FBGJQHU6ll4kdNY5B3CgbM+HA11jUFdCGjmwLZMXbOGys1oHXYqIVFOVIeru14b/7BL7chLb1f3b8tO3VnCkpJTGjZKDLkdEqiGiO5bMLNvMzjOzi049YlVYImqVkUZ+hyzeW7496FJEpJqqHaJm9gVgGvAW8P3wn9+LTVmJS730IvVLJGeiXwXOBTa4+yXAIGBvLIpKZMP7tmHm6l3sP3o86FJEpBoiCdGj7n4UQvfRu/tyoFdsykpcmY1TuaBbS95eoia9SH0QSYgWmVkW8HfgHTN7FdgQi6IS3ciBeUxWk16kXqh2iLr79e6+192/BzwG/BG4LkZ1JbTLzspl3sY97NI0eSJxr1ohambJZrb81PfuPtXdJ7t7SexKS1xNGqVwSa9c/rF4W9CliEgVqhWi7l4KrDCzjjGuR8LUSy9SP0RyTTQbWGJm75nZ5FOPWBWW6D7TsxUrtx9g674jQZciIpWIZD7Rx2JWhZwhLSWZK/u04Y2FW/nCZ3RnrUi8iuRM9JrwtdBPHsA1sSpM1EsvUh9EEqJXlLPt6sqeYGbjzGyHmS2uYP/DZjY//FhsZqVm1qLM/mQzm2dmr0dQZ4NxQbeWbNl7lPU7tQqLSLyqzlR4XzKzRUCv8GTMpx7rgKomZR4PVLikiLs/4e757p4PPApMdffdZQ75KrCsqhobquQkY0T/NupgEolj1TkTfZHQ5MuTw3+eepzj7refOsjMsk9/ortPA3afvr0CY4AJZV6vPTAC+EM1n98gjcrP47WFClGReFVliLr7Pndf7+5j3H1Dmcfp4fheTYswsyaEzlgnltn8S+A/gYSe6n1Qh2wOHStl+bb9QZciIuWIaCq8KtRmmcqRwIxTwWxm1wI73H1OlW9qdq+ZFZpZYXFxw5sVPinJuHZgWzXpReJUNEPUa/HcWyjTlAcuBEaZ2XrgJeBSM/tzuW/q/oy7F7h7QU5OTi1KiF8jB+Qxcc5mNu0+HHQpInKaaIZojZhZJnAx8Oqpbe7+qLu3d/fOhAL2X2Wvvyaafu0yueeirlz35Axenb856HJEpIxIBttX5YzmvJlNAIYBrcysCBgLpAK4++/Ch10PvO3uGsdTic8P7cLgLi148KV5TFu5k++P7ktGWjT/+kSkJsy9+q1wMxsK9HD3Z80sB8hw93XhfS3K6WyqUwUFBV5YWBhkCTF3uOQE35+8lI/W7eJXYwYxoH1W0CWJNHhmNsfdC8rbF8nyIGOBbxEazwmhM8pPrlMGHaCJokmjFP7nxgF888pe3PXsxzw9dQ0nT9bmcrSI1EYk10SvB0YBhwDcfQvQLBZFSdWuHZDHq/dfyDtLt3Pns7PZsf9o0CWJJKRIQrTEQ21/BzCzprEpSaqrfXYTXrr3fM7umM2IX0/n2RnrKD6giZxF6lIkPRN/NbOngSwzuwe4G/h9bMqS6kpJTuKhK3oyrFcOf5q1gV+8s5L8DlmMzm/HlX1b0yw9NegSRRq0SDuWrgCGE+qJf8vd34lVYTWRCB1LVTlSUsq7y7bz6vzNfLR2Nxf1zGF0fh4X98ohLSU56PJE6qXKOpaqHaLh5vtRdy81s16EVvr8p7vHzdq+CtFP23OohH8u3sbf529m5fYD3HlBZx66omfQZYnUO1HpnQemAWlm1g54E7iD0CxNEqeymzbi1sEd+esXL+D1B4byx+nr2Hckbv7PE2kQIglRc/fDwA3AU+5+E9A3NmVJtLXPbsK5nbOZvmpn0KWINCgRhaiZXQDcBrwR3qaLbPXIJb1zmbJiR9BliDQokYTo1wgNtH/F3ZeYWVdgSkyqkpgY1jOXqSuLNThfJIqqPcQpvKbS1DLfrwUejEVREhsdWzaheXoKS7fup1+7zKDLEWkQIrnts8DMJpnZ3LLLhMSyOIm+Yb1ymb
JcTXqRaImkOf8Cod74z/LpZUKkHrmkl66LikRTJHcsFbv75JhVInXi3C7ZrNp+kN2HSmjRtFHQ5YjUe5GE6Fgz+wOhtZQ+uUHb3SdFvSqJmbSUZM7v1pIPVhUzOr9d0OWI1HuRhOhdQG9CU+CdWjzOAYVoPXNJ+LqoQlSk9iIJ0XPdvVfMKpE6M6xXDj97ewWlJ53kpNqsLygikXQszTSzPjGrROpMXlZjcpulsbBob9CliNR71QpRMzNCi8nNN7MV4eFNizTEqf66uFcOU1Y0vCWmRepatUI0PBlzLtCD0FR4I4Fr0RCneuuSXrm8r6FOIrUWyTXRiUCuu38cq2Kk7pzTKZv1Ow9RfOAYOc3Sgi5HpN6K5JroYGCWma1Rc77+S01OYmiPVkxdqSa9SG1EciZ6ZcyqkEAMC9+9dOM57YMuRaTeimQCkg2xLETq3rCeOfzojWWcKD1JSnIkjRIROUW/OQkst3k6HVo0Zu7GvUGXIlJvKUQTnHrpRWpHIZrghmm8qEitKEQTXH6HbLbuO8K2fUeDLkWkXlKIJrjkJOOiHjlq0ovUkEJUuKR3jiZqFqkhhahwUY8cZq7ZRcmJk1UfLCKfohAVWmak0S0ng8L1u4MuRaTeUYgKoLWXRGpKISpA6Lro+xrqJBKxmIaomY0zsx1mtriC/Q+b2fzwY7GZlZpZCzPrYGZTzGypmS0xs6/Gsk6BfnmZ7DlcoomaRSIUyQQkNTEe+A3wfHk73f0J4AkAMxsJPOTuu80sDfiGu881s2bAHDN7x92XxrjehJWUZDxwaQ8+/1whzdNTuPys1lzaO5dzOmXrvnqRSsQ0RN19mpl1rubhY4AJ4edtBbaGvz5gZsuAdoBCNIbuHNKZO87vxKLN+3hv2XZ+8PpSNu89wrCeOVx6Vmsu7plDZuPUoMsUiSsWmrQ+hm8QCtHX3b1fJcc0AYqA7u6++7R9nYFpQD9331/ZexUUFHhhYWGta5Z/27rvCO8t28F7y7bz8fo9XNi9JT++YYDWrJeEYmZz3L2gvH3x0k4bCcwoJ0AzCM2o/7WKAtTM7jWzQjMrLC5Wx0i0tc1szO3nd+LZu85j9ncuo0urDK791QfM3bgn6NJE4kK8hOgthJvyp5hZKqEAfcHdK1zb3t2fcfcCdy/IycmJcZmJrUmjFB65ujffH92Pe54rZPyMdcS6JSMS7wIPUTPLJLSS6KtlthnwR2CZu/8iqNqkfFf0ac2kLw/hr4VFPDBhHgePnQi6JJHAxHqI0wRgFtDLzIrM7PNmdp+Z3VfmsOuBt939UJltFwJ3AJeWGQJ1TSxrlch0atmUSV8eQkZaCqN/M52V2w8EXZJIIGLesVSX1LEUjL8VbuLH/1zO2JF9GJ3fLuhyRKKuso6lWI8TlQRwU0EH+uZl8uUX5lC4fg+PXduHRimBXykSqRP6ly5R0SevOZMfGMrWfUe5c9xs9h05HnRJInVCISpR0zw9lafvOIdebZpx41MzKdpzOOiSRGJOISpRlZxkfG9UX8ac15HPPjWTRUX7gi5JJKYUohITdw/twvdH9ePOZ2fzr+Xbgy5HJGbUsSQxc1W/NuQ2T+OLf5rDg5cd5Y7zOwVdkkjU6UxUYursjtm8fN8FPDt9HT/+xzJOnmw4Q+pEQCEqdaBTy6ZM/NIQ5mzYwwMvzePo8dKgSxKJGoWo1Inspo348xcGA3D/i/N0z700GApRqTPpqcn87835bNp9mNcWbg26HJGoUIhKnWqUksRPPtufx19fyp5DJUGXI1JrClGpc4M6ZjOif1t+9I9lQZciUmsKUQnEN6/sxaw1u5ixemfQpYjUikJUApGRlsLj1/Xl268siri3/khJaVzcm3+45ASlGrKV8BSiEphLe7emf7tMfvnuqmo/Z/v+o1z/2xnc9LuZHApwMmh3585xs3n8da2dmOgUohKosSP78rfCTSzZUvU99qt3HOCG385kVH4eA9tn8Z8TFwY2VOr9lcXsOljCaw
u2sHxbpesnSgOnEJVA5TRL41tX9eaRiYs4UXqywuMK1+/mlmc+5OtX9OTLw7rz+HX92LjrMH/4YF0dVhvi7vz87RV888pefO3yHnxv8hKNe01gClEJ3E0F7WmWnsL4mevL3f/Wkm188U9z+PnN+Xz2nPZAaMzpU7efzdPT1jJzTd12Tr21ZBvucFXfNow5ryN7Dx/nH4u21WkNEj8UohI4M+O/r+/Pk1NWs2n3p+cg/dOHG3js74sZf9d5XNzz06u5ts9uwi//I5+vvjSfLXuP1EmtpSedX7yzkm8M70lSkpGSnMT3RvXlv/+xjCMlup01ESlEJS50btWUey7qynf+vhh3x9154q3ljJu+jpfvG0L/9pnlPm9oj1bcfWEXvvTCXI6diH2IvbZgCxlpKVzSK/eTbed3bcmgjlk8NXVNzN9f4o9CVOLGPZ/pSvGBY0ycu5mHX17I9NW7ePm+C+jYskmlz7vv4q7kZabzvcmx7Sk/XnqSX767km8O70VoVe9/+/Y1Z/H8rPVnnElLw6cQlbiRmpzET27oz7cmLmT3oRIm3DOYlhlpVT7PzHjipoHMXreLl2ZvjFl9E+cUkZfVmCHdW52xLy+rMZ+/sAs/fENDnhKNQlTiysAOWfz9yxfyzB3n0KRR9ecMz0hL4ek7CvjpWytYsGlv1Os6dqKUX723im8M71XhMfdc1JVlWw/wwariqL+/xC+FqMSd/u0zSUmO/J9m99wM/vv6/nz5hbnsOngsqjW9NHsTvds255xO2RUek56azH+NOIvvv7aU45UM15KGRSEqDcpV/dowOj+P+1+cF7Xe8iMlpTw5ZTVfv6Jnlcde0ac1bTPTeX7Whqi8t8Q/hag0ON8Y3ou2Welc/9sZrNt5qNav9/ys9RR0zqZfu/JHCJRlZowd2Ycnp6xmZ5TPhiU+KUSlwUlOMn5+00BuP78TNz41kzcX13wC6ANHj/PMtLU8dHnVZ6GndM9txg2D2vHEmytq/L5SfyhEpUEyM24/vxPjPncuj7++jB+9UbPrlOOmr+finjn0aN0souc9eHkPpqzYEZNOLokvClFp0AZ2yOL1B4ayasdBbv39h2zff7Taz917uITxM9fx1ct7RPy+zdNTefjKXnzlxbmMfXUxL3y0gcL1u+NiCj+JLmtIEycUFBR4YWFh0GVIHDp50vnNlNX8+cMN/N8tg7igW8sqn/M/by5n7+Hj/PiG/jV6T3dn1tpdLN2yn5XbD7By+0FWbT9As/RUerZpRq/WGfRs3YwRA9pGNJxL6p6ZzXH3gnL3KUQlkXywqpiH/rKAu4d25u4Lu1B84Bjb9x9l+/5jbNt/lB37j7Jt/1G27z/K0i37efNrF5GX1Thq73/ypLN57xFW7TjAim0HmbpyBy0z0vjNmEFn3AUVLUV7DmNmtIviz5FoFKIiZWzZe4T7X5zLwqJ95DZLo3VmOq2bpdMmM53c5mm0aZ5O6+bpdM1pStvM2AbP0eOlXP/bmdw6uCN3nN8pKq9ZfOAYs9buYtaancxYvYu9h0tol92Efzw4NGZB3dBVFqJqQ0jCyctqzMQvDcEdkpKCDZX01GR+e9vZ3PjUTAZ1yKrWMKrTHTh6nFlrdjFzzS5mrdnF1n1HOK9LSy7s3pLPDelC99wMLv35+8zftJdBHSu+WUBqRiEqCcnMiJeTsi6tmjJ2VF/uf3Eurz0wlGbpqdV+7uLN+7hr/Mf0btOMC7q15Kc3DqBvXvMz7vi69byOvPDRRoVoDMS0d97MxpnZDjNbXMH+h81sfvix2MxKzaxFeN9VZrbCzFab2SOxrFMkaKMG5jGkeysenbSo2rPkz9mwh889O5vHR/flT58fzJeHdWdgh6xyb5m98Zz2vLVkG/sOa3RAtMV6iNN44KqKdrr7E+6e7+75wKPAVHffbWbJwJPA1UAfYIyZ9YlxrSKB+u61fVi94yAvVmMmqplrdnLP84U8cdNArurXtsrjW2akcUmvXCbNK4pGqV
JGTEPU3acBu6t5+BhgQvjr84DV7r7W3UuAl4DRMShRJG6kpybz5G1n8/O3V7J0S8WL301ZvoMHXpzHk7ee/anJoaty2+BQk74hdSbHg7gYbG9mTQidsU4Mb2oHbCpzSFF4m0iD1i0ng7Ej+/CVF+dysJwlof+5aCsPv7yA399ZUK2xrmWd16UFALPXVfe8RqojLkIUGAnMcPeI/3bN7F4zKzSzwuJizeMo9d/o/Hac37UF3z7t+ujEOUV8d/ISnrv7PM6uQQeRmX3SwSTREy8hegv/bsoDbAY6lPm+fXjbGdz9GXcvcPeCnJyc8g4RqXfGjuzLyu0HeOnjUIPszx9u4Gdvr2DCPYPpmxf5MKhTPnt2e95fsSPq860mssBD1MwygYuBV8ts/hjoYWZdzKwRoZCdHER9IkE4dX30ibdW8P3XlvD0tDX85d4L6J4b2UQop8tsksrwvm14eY46mKIl1kOcJgCzgF5mVmRmnzez+8zsvjKHXQ+87e6fTPzo7ieA+4G3gGXAX919SSxrFYk3p66Pzl63m79+seoF+6rr1sEdeXH2Rk6eVAdTNOi2T5EE4+5c86vpfPua3nymhy6BVUdlt30G3pwXkbplZtw2uCMvqoMpKhSiIglodH4eM1bvjGh+1do6eryUkhMNbwE/hahIAmqWnsqIAXn89eNNVR8cJQ9MmMed42Y3uJVQFaIiCeq2wR2ZMHsjpVV0MB07Ucqv3lvF+Bnravxe01YWs2LbAdJSk/jh60tr/DrxSCEqkqD6tcskp1kaU1fuqPCY2et2c83/fcDCor38+l+rWb6t4ttRK3K89CSPv76U74w4i1+NGcT01TtrdD32eOlJvvPKIt5fUXG9QVCIiiSw2wZ34oUPzwy0fUeO8+ikRTw4YR4PX9mLP9x5Lt8Y3otHJi6q8sz1dC98uIHc5mkM79Oa5ump/OHOc/nFOysiuv306PFS7vvTHN5fUczfCuNrjKtCVCSBXTuwLXM27qFoz2EgNPzpjYVbueIXU0lOgre/ftEns0Tdcm4HGqUk8fys9dV+/T2HSvj1v1bz3Wv7fjKrfpdWTfnFzfnc/+LcT963MoeOneDu8R/TuFEyE780hGmrijl2ojTyHzZGFKIiCaxJoxSuy2/HXz7exJa9R/jCc4X877sr+e1tZ/PD6/rTvMwE0UlJxo9v6M+v/7WazXuPVOv1//fdlYwY0JZebT59p9VFPXP44sXduOf5ORwuOXOilVP2HT7O7X/8iI4tmvB/twyiTWY6PVs346O18TOJikJUJMHdOrgjz81cz4hffcCA9lm88eBQCjq3KPfYbjkZ3DWkM4/9fXGVU+ot37afNxZu5aHLe5a7/+4LO9Mvrznf/NuCcl9r58Fj3PL7Dzm7YzY/vqE/yeGlXC4/qzXvLtse4U/5b9v3H434kkRlFKIiCa5n62Z87fKe/O2+IXz18h6kpSRXevwXL+7G5j1HeH3h1gqPcXcef30pD1zaneymjco9xsz44fX92LbvKL/+1+pP7duy9wg3Pz2L4X1a818jzvrUAntX9Mnl3aXbazwv6n1/nsP01Ttr9NzyKERFhLuHhha0q45GKUn85LP9+cHrS9l7uKTcY95Zup0d+49xWxUrmKalJPO7O87hpdkbeXPxNgDW7zzEzU/PYsy5HXnoip5nrFDaLSeDRilJLN0a+UiBtcUHKdpzhAsjnIu1MgpREYnYoI7ZjOjflh+9seyMfcdOlPLDN5bx3ZF9SC1nvafT5TZL5+k7CvjOK4uYvGALtzzzIV8e1p17Lupa7vFmFmrSL418qNMr8zYzemBeuetQ1ZRCVERq5JtX9mLmml3MOK1pPG76enq2zohocpP+7TMZO6ovX//LfB69pje3Du5Y6fGX94n8uujJk86kuZu54ez2ET2vKgpREamRjLQUHr+uL99+ZRFHj4eGHO3Yf5Rnpq3hOyMiX1dy1MA85jx2BaPzq14JqKBTNpv2HGbrvuqNEgCYvX
43zdJT6JPXPOLaKqMQFZEau7R3a/q3y+SX764C4Im3VnBzQQe6tGpao9fLbJxa9UFASnISw3rm8N6y6jfpJ80t4oazo79Um0JURGpl7Mi+vDxnEy/N3sj7K4u5/9LudfK+kTTpjx4v5a0l26t1lhsphaiI1EpOszT+86rePDJpEd8c3pNm6dU7m6yti3rmULh+D4fKWRX1dG8v3c7ADlm0bp4e9TpSov6KIpJwbjqnPempyYzo37bO3rN5eiqDOmbxwaqdXNWvTaXHTppbxGdj0JQHnYmKSBSYGaMG5n1yV1Fdqc7dSzsOHGXuhj0M71N50NaUQlRE6q3LzsplyvIdld7GOXn+Fob3bUPjRpXfiVVTClERqbfaZzcht3k68zbuqfCYiXM3x6RX/hSFqIjUa1eclcs7FTTpl23dz77DJZzfJXq3eZ5OISoi9drlfVrz7tLyQ/SVeZu5/ux2JMXwWq1CVETqtX55mRw4eoK1xQc/tf1E6Un+Pm8z1w+K7m2ep1OIiki9lpRkXHZW6zPuXpqxZhdtsxpXe3aqGr9/TF9dRKQOXNHnzOuisRwbWpZCVETqvSHdWrF0y372HArNb3rw2An+tXwH1w7Ii/l7K0RFpN5LT01mSLeWTAkvp/zPRVs5v2tLWlQwq340KURFpEEoOyHJpLmb66QpDwpREWkgLu2dywerdrJu5yGWb9vPJb1z6+R9FaIi0iC0ykijZ+tmPDJxISMGtK1ywb1oUYiKSINx+Vmt+Wjd7qgvAVIZTYUnIg3G1f3a8OHaXQzqkFVn76kQFZEGo3Orpjx393l1+p4xbc6b2Tgz22Fmiys5ZpiZzTezJWY2tcz2h8LbFpvZBDOL/pTUIiK1FOtrouOBqyraaWZZwG+BUe7eF7gpvL0d8CBQ4O79gGTglhjXKiISsZiGqLtPA3ZXcsitwCR33xg+vuzNrylAYzNLAZoAW2JWqIhIDQXdO98TyDaz981sjpn9PwB33wz8DNgIbAX2ufvbAdYpIlKuoEM0BTgHGAFcCTxmZj3NLBsYDXQB8oCmZnZ7eS9gZveaWaGZFRYXF9dV3SIiQPAhWgS85e6H3H0nMA0YCFwOrHP3Ync/DkwChpT3Au7+jLsXuHtBTk5OnRUuIgLBh+irwFAzSzGzJsBgYBmhZvz5ZtbEzAy4LLxdRCSuxHScqJlNAIYBrcysCBgLpAK4++/cfZmZvQksBE4Cf3D3xeHnvgzMBU4A84BnYlmriEhNmHvFS43WNwUFBV5YWBh0GSLSwJjZHHcvKG9f0M15EZF6TSEqIlILDao5b2bFwIYIn9YK2BmDcqKtvtQJ9adW1Rl99aXWSOvs5O7lDv9pUCFaE2ZWWNG1jnhSX+qE+lOr6oy++lJrNOtUc15EpBYUoiIitaAQrT/jT+tLnVB/alWd0Vdfao1anQl/TVREpDZ0JioiUgsJHaJmdpWZrTCz1Wb2SND1VMTM1pvZovAKAHF1S1Z5qxeYWQsze8fMVoX/zA6yxnBN5dX5PTPbHP5c55vZNUHWGK6pg5lNMbOl4ZUdvhreHlefaSV1xtVnambpZjbbzBaE6/x+eHsXM/so/Lv/FzNrVOP3SNTmvJklAyuBKwjNJvUxMMbdlwZaWDnMbD2hWf7jbvydmV0EHASeD69CgJn9FNjt7j8J/+eU7e7fisM6vwccdPefBVlbWWbWFmjr7nPNrBkwB7gO+Bxx9JlWUufNxNFnGp7AqKm7HzSzVGA68FXg64QmhH/JzH4HLHD3p2ryHol8JnoesNrd17p7CfASoTlMJQIVrF4wGngu/PVzhH65AlWNVRbigrtvdfe54a8PEJq9rB1x9plWUmdc8ZCD4W9Tww8HLgVeDm+v1eeZyCHaDthU5vsi4vAfQZgDb4dn/7836GKqobW7bw1/vQ1oHWQxVbjfzBaGm/uBX3Yoy8w6A4OAj4jjz/S0OiHOPlMzSzaz+cAO4B1gDbDX3U+ED6nV734ih2h9MtTdzw
auBr4SbprWCx66XhSv14yeAroB+YSWofl5oNWUYWYZwETga+6+v+y+ePpMy6kz7j5Tdy9193ygPaEWaO9ovn4ih+hmoEOZ79uHt8Wd8JpTpxbye4XQP4R4tj18zezUtbMdVRwfCHffHv4FOwn8njj5XMPX7iYCL7j7pPDmuPtMy6szXj9TAHffC0wBLgCywotgQi1/9xM5RD8GeoR76RoRWpJ5csA1ncHMmoYv3GNmTYHhwOLKnxW4ycCd4a/vJLSCQdw5FUph1xMHn2u4I+SPwDJ3/0WZXXH1mVZUZ7x9pmaWY6Gl2TGzxoQ6kpcRCtMbw4fV6vNM2N55gPDwi18SWtd+nLv/KNiKzmRmXQmdfUJoJYIX46nOsqsXANsJrV7wd+CvQEdCs2rd7O6BdupUUOcwQs1OB9YDXyxz3TEQZjYU+ABYRGi1B4BvE7reGDefaSV1jiGOPlMzG0Co4yiZ0EnjX939B+Hfq5eAFoRWzrjd3Y/V6D0SOURFRGorkZvzIiK1phAVEakFhaiISC0oREVEakEhKiJSCwpRkQqY2TAzez3oOiS+KURFRGpBISr1npndHp4zcr6ZPR2ecOKgmf1veA7J98wsJ3xsvpl9GJ4g45VTE2SYWXczezc87+RcM+sWfvkMM3vZzJab2QvhO3VEPqEQlXrNzM4C/gO4MDzJRClwG9AUKHT3vsBUQncoATwPfMvdBxC62+bU9heAJ919IDCE0OQZEJqd6GtAH6ArcGGMfySpZ1KqPkQkrl0GnAN8HD5JbExoco6TwF/Cx/wZmGRmmUCWu08Nb38O+Ft4boJ27v4KgLsfBQi/3mx3Lwp/Px/oTGhiXxFAISr1nwHPufujn9po9thpx9X0/uay91OXot8ZOY2a81LfvQfcaGa58MlaRJ0I/ds+NUvPrcB0d98H7DGzz4S33wFMDc/MXmRm14VfI83MmtTlDyH1l/5XlXrN3Zea2X8Rmvk/CTgOfAU4BJwX3reD0HVTCE179rtwSK4F7gpvvwN42sx+EH6Nm+rwx5B6TLM4SYNkZgfdPSPoOqThU3NeRKQWdCYqIlILOhMVEakFhaiISC0oREVEakEhKiJSCwpREZFaUIiKiNTC/wdn2YE5l/9kaQAAAABJRU5ErkJggg==",
                        "text/plain": [
                            "<Figure size 360x360 with 1 Axes>"
                        ]
                    },
                    "metadata": {
                        "needs_background": "light"
                    },
                    "output_type": "display_data"
                }
            ],
            "source": [
                "# Model Fit\n",
                "with Timer() as train_time:\n",
                "    model.fit(Xtr)\n",
                "\n",
                "print(\"Took {:.2f} seconds for training.\".format(train_time.interval))\n",
                "\n",
                "# Plot the train RMSE as a function of the epochs\n",
                "line_graph(values=model.rmse_train, labels='train', x_name='epoch', y_name='rmse_train')"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "During training, we can optionally evaluate the root mean squared error (RMSE) to get an idea of how the learning is proceeding. We would generally like to see this quantity decrease as a function of the learning epochs. To compute it, set `with_metrics=True` when instantiating the `RBM()` model.\n",
                "\n",
                "Once the model has been trained, we can predict new ratings on the test set."
            ]
        },
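        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "As a rough illustration of the quantity tracked when `with_metrics=True`, the following NumPy sketch computes an RMSE restricted to observed ratings. It assumes that zeros in the user/item matrix mark unrated items and that only rated entries enter the error. The function `masked_rmse` and the toy matrices are hypothetical, and the library's own computation may differ in detail.\n",
                "\n",
                "```python\n",
                "import numpy as np\n",
                "\n",
                "def masked_rmse(x_true, x_pred):\n",
                "    # Only entries with an observed (non-zero) rating contribute to the error\n",
                "    mask = x_true > 0\n",
                "    diff = (x_true - x_pred)[mask]\n",
                "    return float(np.sqrt(np.mean(diff ** 2)))\n",
                "\n",
                "# Toy 2-user x 3-item matrices; 0 marks an unrated item\n",
                "x_true = np.array([[5, 0, 3], [0, 4, 1]])\n",
                "x_pred = np.array([[4, 2, 3], [5, 4, 2]])\n",
                "print(masked_rmse(x_true, x_pred))  # sqrt(0.5), about 0.707\n",
                "```\n",
                "\n",
                "Only four entries are rated here, with errors 1, 0, 0 and -1, so the predictions for unrated items (2 and 5) do not affect the result."
            ]
        },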
        {
            "cell_type": "code",
            "execution_count": 9,
            "metadata": {
                "tags": [
                    "top_k"
                ]
            },
            "outputs": [
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "Took 0.23 seconds for prediction.\n"
                    ]
                }
            ],
            "source": [
                "# Number of top-scored items to recommend to each user\n",
                "K = 10\n",
                "\n",
                "# Model prediction on the test set Xtst\n",
                "with Timer() as prediction_time:\n",
                "    top_k = model.recommend_k_items(Xtst, top_k=K)\n",
                "\n",
                "print(\"Took {:.2f} seconds for prediction.\".format(prediction_time.interval))"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "`top_k` contains, for each user, the K items with the highest recommendation score. Here the recommendation score is the predicted rating multiplied by its probability, i.e. the confidence the algorithm has in its output. So if two items both have a predicted rating of 5, but one with probability 0.5 and the other with probability 0.9, the latter is considered more relevant. To inspect the predictions and use the evaluation metrics in this repository, we convert both `top_k` and `Xtst` to pandas dataframes:"
            ]
        },
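        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "The scoring rule described above can be sketched in NumPy for a single user. This assumes a hypothetical matrix `probs` holding, for each item, the probability of each rating value, and takes the predicted rating to be the most probable one. It illustrates the idea rather than reproducing the internals of `recommend_k_items`.\n",
                "\n",
                "```python\n",
                "import numpy as np\n",
                "\n",
                "# Hypothetical per-item probabilities over the rating values 1..5 for one user\n",
                "# (rows: items, columns: rating values); each row sums to 1\n",
                "ratings = np.array([1, 2, 3, 4, 5])\n",
                "probs = np.array([\n",
                "    [0.0, 0.0, 0.1, 0.4, 0.5],  # item 0: most likely rating 5, p = 0.5\n",
                "    [0.0, 0.0, 0.0, 0.1, 0.9],  # item 1: most likely rating 5, p = 0.9\n",
                "    [0.1, 0.1, 0.6, 0.1, 0.1],  # item 2: most likely rating 3, p = 0.6\n",
                "])\n",
                "\n",
                "best = probs.argmax(axis=1)       # index of the most probable rating per item\n",
                "pred_rating = ratings[best]       # predicted rating\n",
                "confidence = probs.max(axis=1)    # probability of that rating\n",
                "score = pred_rating * confidence  # recommendation score\n",
                "\n",
                "k = 2\n",
                "top_items = np.argsort(-score)[:k]  # items sorted by decreasing score\n",
                "print(score, top_items)\n",
                "```\n",
                "\n",
                "Items 0 and 1 are both predicted a rating of 5, but item 1 (score 4.5) ranks above item 0 (score 2.5) because its prediction is more confident, matching the example in the text."
            ]
        },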
        {
            "cell_type": "code",
            "execution_count": 10,
            "metadata": {},
            "outputs": [],
            "source": [
                "top_k_df = am.map_back_sparse(top_k, kind='prediction')\n",
                "test_df = am.map_back_sparse(Xtst, kind='ratings')"
            ]
        },
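        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Conceptually, this conversion turns a dense user/item matrix into a long table with one row per non-zero entry. The sketch below assumes hypothetical index-to-ID dictionaries (`idx2user`, `idx2item`) and toy scores; it is not the `map_back_sparse` implementation itself.\n",
                "\n",
                "```python\n",
                "import numpy as np\n",
                "import pandas as pd\n",
                "\n",
                "# Dense user x item score matrix; 0 means the item was not recommended\n",
                "scores = np.array([[0.0, 4.9, 4.7], [4.8, 0.0, 0.0]])\n",
                "# Hypothetical maps from matrix indices back to the original IDs\n",
                "idx2user = {0: 1, 1: 2}\n",
                "idx2item = {0: 100, 1: 65, 2: 129}\n",
                "\n",
                "rows, cols = np.nonzero(scores)  # positions of the non-zero entries, row by row\n",
                "df = pd.DataFrame({\n",
                "    'userID': [idx2user[r] for r in rows],\n",
                "    'movieID': [idx2item[c] for c in cols],\n",
                "    'prediction': scores[rows, cols],\n",
                "})\n",
                "print(df)\n",
                "```"
            ]
        },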
        {
            "cell_type": "code",
            "execution_count": 11,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/html": [
                            "<div>\n",
                            "<style scoped>\n",
                            "    .dataframe tbody tr th:only-of-type {\n",
                            "        vertical-align: middle;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe tbody tr th {\n",
                            "        vertical-align: top;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe thead th {\n",
                            "        text-align: right;\n",
                            "    }\n",
                            "</style>\n",
                            "<table border=\"1\" class=\"dataframe\">\n",
                            "  <thead>\n",
                            "    <tr style=\"text-align: right;\">\n",
                            "      <th></th>\n",
                            "      <th>userID</th>\n",
                            "      <th>movieID</th>\n",
                            "      <th>prediction</th>\n",
                            "    </tr>\n",
                            "  </thead>\n",
                            "  <tbody>\n",
                            "    <tr>\n",
                            "      <th>0</th>\n",
                            "      <td>1</td>\n",
                            "      <td>100</td>\n",
                            "      <td>4.881824</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>1</th>\n",
                            "      <td>1</td>\n",
                            "      <td>65</td>\n",
                            "      <td>4.822650</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>2</th>\n",
                            "      <td>1</td>\n",
                            "      <td>129</td>\n",
                            "      <td>4.672100</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>3</th>\n",
                            "      <td>1</td>\n",
                            "      <td>1104</td>\n",
                            "      <td>4.898961</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>4</th>\n",
                            "      <td>1</td>\n",
                            "      <td>1123</td>\n",
                            "      <td>4.664860</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>5</th>\n",
                            "      <td>1</td>\n",
                            "      <td>1418</td>\n",
                            "      <td>4.611925</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>6</th>\n",
                            "      <td>1</td>\n",
                            "      <td>1427</td>\n",
                            "      <td>4.722356</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>7</th>\n",
                            "      <td>1</td>\n",
                            "      <td>1521</td>\n",
                            "      <td>4.738353</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>8</th>\n",
                            "      <td>1</td>\n",
                            "      <td>1583</td>\n",
                            "      <td>4.569103</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>9</th>\n",
                            "      <td>1</td>\n",
                            "      <td>1546</td>\n",
                            "      <td>4.890738</td>\n",
                            "    </tr>\n",
                            "  </tbody>\n",
                            "</table>\n",
                            "</div>"
                        ],
                        "text/plain": [
                            "   userID  movieID  prediction\n",
                            "0       1      100    4.881824\n",
                            "1       1       65    4.822650\n",
                            "2       1      129    4.672100\n",
                            "3       1     1104    4.898961\n",
                            "4       1     1123    4.664860\n",
                            "5       1     1418    4.611925\n",
                            "6       1     1427    4.722356\n",
                            "7       1     1521    4.738353\n",
                            "8       1     1583    4.569103\n",
                            "9       1     1546    4.890738"
                        ]
                    },
                    "execution_count": 11,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "top_k_df.head(10)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## 4 Evaluation metrics\n",
                "\n",
                "Here we evaluate the performance of the algorithm using the ranking metrics provided in this repository (`map_at_k`, `ndcg_at_k`, `precision_at_k` and `recall_at_k`). Note that these metrics take into account only the top K recommended items, so their values may differ from those reported by the `model.fit()` method."
            ]
        },
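        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "For intuition, precision@k and recall@k for a single user can be computed as below, treating the items the user rated in the test set as relevant. This is a simplified sketch under that assumption, and the function name is hypothetical; the repository's functions also average the per-user values over all users and may handle edge cases differently.\n",
                "\n",
                "```python\n",
                "def precision_recall_at_k(recommended, relevant, k):\n",
                "    # recommended: ranked list of item IDs; relevant: set of items rated in the test set\n",
                "    top = recommended[:k]\n",
                "    hits = sum(1 for item in top if item in relevant)\n",
                "    precision = hits / k\n",
                "    recall = hits / len(relevant) if relevant else 0.0\n",
                "    return precision, recall\n",
                "\n",
                "recommended = [100, 65, 129, 1104]\n",
                "relevant = {65, 1104, 50}\n",
                "print(precision_recall_at_k(recommended, relevant, k=2))\n",
                "```\n",
                "\n",
                "With k=2 only item 65 is a hit, giving precision 1/2 and recall 1/3."
            ]
        },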
        {
            "cell_type": "code",
            "execution_count": 12,
            "metadata": {
                "tags": [
                    "ranking"
                ]
            },
            "outputs": [],
            "source": [
                "def ranking_metrics(\n",
                "    data_size,\n",
                "    data_true,\n",
                "    data_pred,\n",
                "    K\n",
                "):\n",
                "\n",
                "    eval_map = map_at_k(data_true, data_pred, col_user=\"userID\", col_item=\"movieID\",\n",
                "                        col_rating=\"rating\", col_prediction=\"prediction\",\n",
                "                        relevancy_method=\"top_k\", k=K)\n",
                "\n",
                "    eval_ndcg = ndcg_at_k(data_true, data_pred, col_user=\"userID\", col_item=\"movieID\",\n",
                "                          col_rating=\"rating\", col_prediction=\"prediction\",\n",
                "                          relevancy_method=\"top_k\", k=K)\n",
                "\n",
                "    eval_precision = precision_at_k(data_true, data_pred, col_user=\"userID\", col_item=\"movieID\",\n",
                "                                    col_rating=\"rating\", col_prediction=\"prediction\",\n",
                "                                    relevancy_method=\"top_k\", k=K)\n",
                "\n",
                "    eval_recall = recall_at_k(data_true, data_pred, col_user=\"userID\", col_item=\"movieID\",\n",
                "                              col_rating=\"rating\", col_prediction=\"prediction\",\n",
                "                              relevancy_method=\"top_k\", k=K)\n",
                "\n",
                "    df_result = pd.DataFrame(\n",
                "        {\n",
                "            \"Dataset\": data_size,\n",
                "            \"K\": K,\n",
                "            \"MAP\": eval_map,\n",
                "            \"nDCG@k\": eval_ndcg,\n",
                "            \"Precision@k\": eval_precision,\n",
                "            \"Recall@k\": eval_recall,\n",
                "        },\n",
                "        index=[0]\n",
                "    )\n",
                "\n",
                "    return df_result"
            ]
        },
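        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "The nDCG@k reported by `ranking_metrics` rewards relevant items that appear near the top of the list. A minimal sketch for one user, assuming binary relevance (an item is relevant if it appears in the user's test set) and a hypothetical function name, is:\n",
                "\n",
                "```python\n",
                "import math\n",
                "\n",
                "def binary_ndcg_at_k(recommended, relevant, k):\n",
                "    # Binary relevance: an item scores 1 if it is in the user's test set, else 0;\n",
                "    # each hit is discounted by log2 of its (1-based) position + 1\n",
                "    dcg = sum(1.0 / math.log2(i + 2)\n",
                "              for i, item in enumerate(recommended[:k]) if item in relevant)\n",
                "    # Ideal DCG: all relevant items (up to k) placed at the top of the list\n",
                "    n_ideal = min(k, len(relevant))\n",
                "    idcg = sum(1.0 / math.log2(i + 2) for i in range(n_ideal))\n",
                "    return dcg / idcg if idcg > 0 else 0.0\n",
                "\n",
                "print(binary_ndcg_at_k([100, 65, 129], {65, 129}, k=3))\n",
                "```\n",
                "\n",
                "Here the two relevant items sit at positions 2 and 3 instead of 1 and 2, so the score is about 0.69 rather than 1."
            ]
        },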
        {
            "cell_type": "code",
            "execution_count": 13,
            "metadata": {
                "scrolled": true
            },
            "outputs": [
                {
                    "data": {
                        "text/html": [
                            "<div>\n",
                            "<style scoped>\n",
                            "    .dataframe tbody tr th:only-of-type {\n",
                            "        vertical-align: middle;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe tbody tr th {\n",
                            "        vertical-align: top;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe thead th {\n",
                            "        text-align: right;\n",
                            "    }\n",
                            "</style>\n",
                            "<table border=\"1\" class=\"dataframe\">\n",
                            "  <thead>\n",
                            "    <tr style=\"text-align: right;\">\n",
                            "      <th></th>\n",
                            "      <th>Dataset</th>\n",
                            "      <th>K</th>\n",
                            "      <th>MAP</th>\n",
                            "      <th>nDCG@k</th>\n",
                            "      <th>Precision@k</th>\n",
                            "      <th>Recall@k</th>\n",
                            "    </tr>\n",
                            "  </thead>\n",
                            "  <tbody>\n",
                            "    <tr>\n",
                            "      <th>0</th>\n",
                            "      <td>mv 100k</td>\n",
                            "      <td>10</td>\n",
                            "      <td>0.140828</td>\n",
                            "      <td>0.411124</td>\n",
                            "      <td>0.336267</td>\n",
                            "      <td>0.212256</td>\n",
                            "    </tr>\n",
                            "  </tbody>\n",
                            "</table>\n",
                            "</div>"
                        ],
                        "text/plain": [
                            "   Dataset   K       MAP    nDCG@k  Precision@k  Recall@k\n",
                            "0  mv 100k  10  0.140828  0.411124     0.336267  0.212256"
                        ]
                    },
                    "execution_count": 13,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "eval_100k = ranking_metrics(\n",
                "    data_size=\"mv 100k\",\n",
                "    data_true=test_df,\n",
                "    data_pred=top_k_df,\n",
                "    K=10\n",
                ")\n",
                "\n",
                "eval_100k"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 14,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "application/scrapbook.scrap.json+json": {
                            "data": 0.14082811192026132,
                            "encoder": "json",
                            "name": "map",
                            "version": 1
                        }
                    },
                    "metadata": {
                        "scrapbook": {
                            "data": true,
                            "display": false,
                            "name": "map"
                        }
                    },
                    "output_type": "display_data"
                },
                {
                    "data": {
                        "application/scrapbook.scrap.json+json": {
                            "data": 0.41112362614927883,
                            "encoder": "json",
                            "name": "ndcg",
                            "version": 1
                        }
                    },
                    "metadata": {
                        "scrapbook": {
                            "data": true,
                            "display": false,
                            "name": "ndcg"
                        }
                    },
                    "output_type": "display_data"
                },
                {
                    "data": {
                        "application/scrapbook.scrap.json+json": {
                            "data": 0.3362672322375398,
                            "encoder": "json",
                            "name": "precision",
                            "version": 1
                        }
                    },
                    "metadata": {
                        "scrapbook": {
                            "data": true,
                            "display": false,
                            "name": "precision"
                        }
                    },
                    "output_type": "display_data"
                },
                {
                    "data": {
                        "application/scrapbook.scrap.json+json": {
                            "data": 0.2122560190189148,
                            "encoder": "json",
                            "name": "recall",
                            "version": 1
                        }
                    },
                    "metadata": {
                        "scrapbook": {
                            "data": true,
                            "display": false,
                            "name": "recall"
                        }
                    },
                    "output_type": "display_data"
                }
            ],
            "source": [
                "# Record results for tests - ignore this cell\n",
                "store_metadata(\"map\", eval_100k['MAP'][0])\n",
                "store_metadata(\"ndcg\", eval_100k['nDCG@k'][0])\n",
                "store_metadata(\"precision\", eval_100k['Precision@k'][0])\n",
                "store_metadata(\"recall\", eval_100k['Recall@k'][0])"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## 5 Saving the model and loading a pre-trained model\n",
                "The trained model checkpoint can be saved to a specified path using the `save` method."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 15,
            "metadata": {},
            "outputs": [],
            "source": [
                "model.save(file_path='./models/rbm_model.ckpt')"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "A pre-trained RBM model can be restored with the `load` method, for example to resume training."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 16,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Initialize the model; visible_units and hidden_units must match the saved checkpoint\n",
                "model = RBM(\n",
                "    possible_ratings=np.setdiff1d(np.unique(Xtr), np.array([0])),\n",
                "    visible_units=Xtr.shape[1],\n",
                "    hidden_units=600,\n",
                "    training_epoch=30,\n",
                "    minibatch_size=60,\n",
                "    keep_prob=0.9,\n",
                "    with_metrics=True\n",
                ")\n",
                "\n",
                "# Load the model checkpoint\n",
                "model.load(file_path='./models/rbm_model.ckpt')"
            ]
        }
    ],
    "metadata": {
        "celltoolbar": "Tags",
        "interpreter": {
            "hash": "67434505f7f08e5031eee7757e853265d2f43dd6b5963eb755a27835ec0e1503"
        },
        "kernel_info": {
            "name": "python3"
        },
        "kernelspec": {
            "display_name": "tf37",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.7.12"
        },
        "nteract": {
            "version": "0.12.3"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 4
}
